{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dualing CNN DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import gym\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "from collections import namedtuple, deque\n",
    "from itertools import count\n",
    "import time\n",
    "import datetime\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.\n"
     ]
    }
   ],
   "source": [
    "from src.utils.Config import Config\n",
    "from src.utils.Logging import Logger\n",
    "from src.utils.atari_wrappers import wrap_deepmind, make_atari"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.\n",
      "  result = entry_point.load(False)\n"
     ]
    }
   ],
   "source": [
    "seed = 1234\n",
    "\n",
    "env_id = \"PongNoFrameskip-v4\"\n",
    "env    = make_atari(env_id)\n",
    "env    = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=True, scale=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "observation shape: (84, 84, 4)\n",
      "Action space: 6\n"
     ]
    }
   ],
   "source": [
    "print(\"observation shape: {}\".format(env.observation_space.shape))\n",
    "print(\"Action space: {}\".format(env.action_space.n))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parsing state from lazy frame to torch tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_state(state):\n",
    "    state = np.array(state).transpose((2, 0, 1))\n",
    "    state = torch.from_numpy(state)\n",
    "    state = state.unsqueeze(0).float() / 256\n",
    "    return state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dualing CNN Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DualingCNNQNetwork(nn.Module):\n",
    "    def __init__(self, channel_in, action_size, seed, fc1_units=512, fc2_unit=512):\n",
    "        # Call inheritance\n",
    "        super(DualingCNNQNetwork, self).__init__()\n",
    "        self.seed = torch.manual_seed(1234)\n",
    "        \n",
    "        self.conv_head = nn.Sequential(\n",
    "            nn.Conv2d(4, 32, kernel_size=8, stride=4),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=1),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        \n",
    "\n",
    "        self.advantage = nn.Sequential(\n",
    "            nn.Linear(64 * 7 * 7, fc1_units),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(fc1_units, action_size)\n",
    "        )\n",
    "\n",
    "        self.value = nn.Sequential(\n",
    "            nn.Linear(64 * 7 * 7, fc1_units),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(fc1_units, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = self.conv_head(state)\n",
    "        x = x.view(x.size(0), -1) # Flatten\n",
    "        advantage = self.advantage(x)\n",
    "        value = self.value(x)\n",
    "\n",
    "        x = value + advantage - advantage.mean()\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    \"\"\"Fixed-size buffer to store experience tuples.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        \"\"\"Initialize a ReplayBuffer object.\n",
    "\n",
    "        Params\n",
    "        ======\n",
    "        action_size (int): dimension of each action\n",
    "        buffer_size (int): maximum size of buffer\n",
    "        batch_size (int): size of each training batch\n",
    "        seed (int): random seed\n",
    "        \"\"\"\n",
    "\n",
    "        self.memory = deque(maxlen=config.buffer_size)  \n",
    "        self.batch_size = config.batch_size\n",
    "        self.experience = namedtuple(\"Experience\", field_names=[\"state\", \"action\", \"reward\", \"next_state\", \"done\"])\n",
    "        self.seed = random.seed(config.seed)\n",
    "        self.device = config.device\n",
    "\n",
    "    def add(self, state, action, reward, next_state, done, error=0):\n",
    "        \"\"\"Add a new experience to memory.\"\"\"\n",
    "        e = self.experience(state, action, reward, next_state, done)\n",
    "        self.memory.append(e)\n",
    "\n",
    "    def sample(self):\n",
    "        \"\"\"Randomly sample a batch of experiences from memory.\"\"\"\n",
    "        experiences = random.sample(self.memory, k=self.batch_size)\n",
    "        \n",
    "#         pdb.set_trace()\n",
    "\n",
    "        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(self.device)\n",
    "        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(self.device)\n",
    "        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)\n",
    "        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(self.device)\n",
    "        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(self.device)\n",
    "        weights = torch.ones(len(experiences)).float().to(self.device)\n",
    "        memory_loc = torch.zeros(len(experiences)).float().to(self.device)\n",
    "\n",
    "        return (states, actions, rewards, next_states, dones, weights, memory_loc)\n",
    "\n",
    "    def n_entries(self):\n",
    "        return len(self.memory)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the current size of internal memory.\"\"\"\n",
    "        return len(self.memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Agent():\n",
    "    def __init__(self, state_size, action_size, config):\n",
    "        \"\"\"\n",
    "        Initialize an Agent object.\n",
    "\n",
    "        state_size (int): dimension of each state\n",
    "        action_size (int): dimension of each action\n",
    "        \"\"\"\n",
    "\n",
    "        if config.model is None:\n",
    "            raise Exception(\"Please select a Model for agent\")\n",
    "        if config.memory is None:\n",
    "            raise Exception(\"Please select Memory for agent\") \n",
    "\n",
    "        self.config = config\n",
    "\n",
    "        self.state_size = state_size\n",
    "        self.action_size = action_size\n",
    "        self.seed = random.seed(self.config.seed)\n",
    "\n",
    "        # Q-Network\n",
    "        self.qnetwork_local = self.config.model(state_size, action_size, self.seed, 64, 64).to(self.config.device)\n",
    "        self.qnetwork_target = self.config.model(state_size, action_size, self.seed, 64, 64).to(self.config.device)\n",
    "\n",
    "        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.config.lr)\n",
    "\n",
    "        # Replay memory\n",
    "        self.memory = self.config.memory(config)\n",
    "\n",
    "        # Initialize time step (for updating every UPDATE_EVERY steps)\n",
    "        self.t_step = 0\n",
    "        \n",
    "        # Keep \n",
    "        self.eps = self.config.eps_start\n",
    "    \n",
    "    def step(self, state, action, reward, next_state, done):\n",
    "        state = self.parse_state(state)\n",
    "        next_state = self.parse_state(next_state)\n",
    "        \n",
    "        # Save experience in replay memory\n",
    "        self.memory.add(state, action, reward, next_state, done) \n",
    "\n",
    "        # Learn every UPDATE_EVERY time steps.\n",
    "        self.t_step = (self.t_step + 1) % self.config.learn_every\n",
    "\n",
    "        if self.t_step == 0:\n",
    "            # If enough samples are available in memory, get random subset and learn\n",
    "            if self.memory.n_entries() > self.config.batch_size:\n",
    "                experiences = self.memory.sample()\n",
    "                self.learn(experiences, self.config.gamma)\n",
    "\n",
    "    def anneal_eps(self):\n",
    "        self.eps = max(self.config.eps_end, self.config.eps_decay*self.eps) \n",
    "    \n",
    "    def parse_state(self, state):\n",
    "        state = np.array(state).transpose((2, 0, 1))\n",
    "        state = torch.from_numpy(state)\n",
    "        state = state.unsqueeze(0).float() / 256\n",
    "        return state\n",
    "    \n",
    "    def act(self, state, network_only=False):\n",
    "        state = self.parse_state(state).to(self.config.device)\n",
    "\n",
    "        self.qnetwork_local.eval()\n",
    "        with torch.no_grad():\n",
    "            action_values = self.qnetwork_local(state)\n",
    "        self.qnetwork_local.train()\n",
    "        \n",
    "        # Epsilon-greedy action selection\n",
    "        if random.random() > self.eps:\n",
    "            return np.argmax(action_values.cpu().data.numpy())\n",
    "        else:\n",
    "            return random.choice(np.arange(self.action_size))\n",
    "    \n",
    "    def learn(self, experiences, gamma):\n",
    "        \"\"\"Update value parameters using given batch of experience tuples.\n",
    "\n",
    "        Params\n",
    "        ======\n",
    "        experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples \n",
    "        gamma (float): discount factor\n",
    "        \"\"\"\n",
    "        states, actions, rewards, next_states, dones, is_weights, idxs = experiences\n",
    "\n",
    "        # Get expected Q values from local model\n",
    "        Q_expected = self.qnetwork_local(states).gather(1, actions)\n",
    "\n",
    "\n",
    "        # Get max action from network\n",
    "        max_next_actions = self.get_max_next_actions(next_states)\n",
    "\n",
    "\n",
    "        # Get max_next_q_values -> .gather(\"dim\", \"index\")\n",
    "        max_next_q_values = self.qnetwork_target(next_states).gather(1, max_next_actions)\n",
    "\n",
    "        # Y^q\n",
    "        Q_targets = rewards + (gamma * max_next_q_values * (1 - dones))\n",
    "\n",
    "        loss = F.mse_loss(Q_expected, Q_targets)\n",
    "\n",
    "        # Minimize the loss\n",
    "        # - Optimizer is initilaised with qnetwork_local, so will update that one.\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "        # ------------------- update target network ------------------- #\n",
    "        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.config.tau)\n",
    "    \n",
    "    def soft_update(self, local_model, target_model, tau):\n",
    "        \"\"\"Soft update model parameters.\n",
    "        θ_target = τ*θ_local + (1 - τ)*θ_target\n",
    "\n",
    "        Params\n",
    "        ======\n",
    "        local_model (PyTorch model): weights will be copied from\n",
    "        target_model (PyTorch model): weights will be copied to\n",
    "        tau (float): interpolation parameter \n",
    "        \"\"\"\n",
    "        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):\n",
    "            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)\n",
    "    \n",
    "    def get_max_next_actions(self, next_states):\n",
    "        '''\n",
    "        Passing next_states through network will give array of all action values (in batches)\n",
    "        eg: [[0.25, -0.35], [-0.74, -0.65], ...]\n",
    "        - Detach will avoid gradients being calculated on the variables\n",
    "        - max() gives max value in whole array (0.25)\n",
    "        - max(0) gives max in dim=0 ([0.25, -0.35])\n",
    "        - max(1) gives max in dim=1, therefore: two tensors, one of max values and one of index\n",
    "        eg: \n",
    "        - values=tensor([0.25, -0.65, ...])\n",
    "        - indices=tensor([0, 1, ...])\n",
    "        - we'll want to take the [1] index as that is only the action\n",
    "        eg: tensor([0, 1, ...]) \n",
    "        - unsqueeze allows us to create them into an array that is usuable for next bit (.view(-1,1) would also work)\n",
    "        eg: eg: tensor([[0], [1], ...]) \n",
    "        '''\n",
    "        return self.qnetwork_target(next_states).detach().max(1)[1].unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent Configuration:\n",
      "env: \t\tEnvSpec(PongNoFrameskip-v4)\n",
      "win condition: \t10000\n",
      "device: \tcpu\n",
      "seed: \t\t123456789\n",
      "n_episodes: \t1300\n",
      "max_t: \t\t1000\n",
      "eps_start: \t1.0\n",
      "eps_end: \t0.01\n",
      "eps_decay: \t0.995\n",
      "eps_greedy: \tTrue\n",
      "noisy: \t\tFalse\n",
      "tau: \t\t0.001\n",
      "gamma: \t\t0.99\n",
      "lr: \t\t0.0005\n",
      "memory: \t<class '__main__.ReplayBuffer'>\n",
      "batch_size: \t32\n",
      "buffer_size: \t100000\n",
      "lr_annealing: \tFalse\n",
      "learn_every: \t4\n",
      "double_dqn: \tFalse\n",
      "model: \t\t<class '__main__.DualingCNNQNetwork'>\n",
      "save_loc: \tNone\n",
      "<_sre.SRE_Match object; span=(0, 27), match='EnvSpec(PongNoFrameskip-v4)'>\n",
      "Logging at: logs/PongNoFrameskip-v4/experiment-2020-05-25_16_40_39\n"
     ]
    }
   ],
   "source": [
    "config = Config()\n",
    "\n",
    "config.env = env\n",
    "\n",
    "config.win_condition = 10000\n",
    "config.n_episodes = 1300\n",
    "config.memory = ReplayBuffer\n",
    "config.batch_size = 32\n",
    "config.model = DualingCNNQNetwork\n",
    "config.print_config()\n",
    "\n",
    "logger = Logger(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(config, logger):\n",
    "    experiment_start = time.time()\n",
    "    env = config.env\n",
    "    frame = 0\n",
    "\n",
    "    agent = Agent(4, action_size=env.action_space.n, config=config)\n",
    "\n",
    "    total_scores_deque = deque(maxlen=100)\n",
    "    total_scores = []\n",
    "\n",
    "\n",
    "    for i_episode in range(1, config.n_episodes+1):\n",
    "        states = env.reset()\n",
    "        scores = 0\n",
    "\n",
    "        start_time = time.time()\n",
    "\n",
    "        for t in count():\n",
    "            actions = agent.act(states)\n",
    "            next_states, rewards, dones, _ = env.step(np.array(actions))\n",
    "            agent.step(states, actions, rewards, next_states,dones)\n",
    "\n",
    "\n",
    "\n",
    "            states = next_states\n",
    "            scores += rewards\n",
    "            number_of_time_steps = t\n",
    "            frame += 1\n",
    "\n",
    "            if np.any(dones):\n",
    "                break \n",
    "\n",
    "        agent.anneal_eps()\n",
    "\n",
    "\n",
    "        # Book Keeping\n",
    "        mean_score = np.mean(scores)\n",
    "        min_score = np.min(scores)\n",
    "        max_score = np.max(scores)\n",
    "\n",
    "        total_scores_deque.append(mean_score)\n",
    "        total_scores.append(mean_score)\n",
    "        total_average_score = np.mean(total_scores_deque)\n",
    "\n",
    "        logger.log_scalar(\"score\", mean_score, i_episode)\n",
    "        logger.log_scalar(\"average_score\", total_average_score, i_episode)\n",
    "\n",
    "        duration = time.time() - start_time\n",
    "\n",
    "        # print('\\rEpisode {}\\tTotal Average Score (in 100 window): {:.4f}\\tMean: {:.4f}\\tMin: {:.4f}\\tMax: {:.4f}\\tDuration: {:.4f}\\t#TimeSteps: {:.4f}'.format(i_episode, total_average_score, mean_score, min_score, max_score, duration, number_of_time_steps), end=\"\")\n",
    "        print('\\rEpi: {}\\t Frame: {} \\tAverage: {:.4f}\\tMean: {:.4f}\\tDuration: {:.2f}\\t#t_s: {:.1f}'.format(i_episode, frame, total_average_score, mean_score, duration, number_of_time_steps), end=\"\")\n",
    "\n",
    "        if i_episode % 100 == 0:\n",
    "            print('\\rEpi: {}\\t Frame: {}\\tAverage Score: {:.4f}\\tMean: {:.4f}\\tDuration: {:.2f}\\t#t_s: {:.1f}'.format(i_episode, frame, total_average_score, mean_score, duration, number_of_time_steps))\n",
    "            torch.save(agent.qnetwork_local.state_dict(), \"{}/checkpoint.pth\".format(logger.log_file_path, date=datetime.datetime.now()))\n",
    "        if config.win_condition is not None and total_average_score > config.win_condition: \n",
    "            print(\"\\nEnvironment Solved in {:.4f} seconds !\".format(time.time() - experiment_start))\n",
    "            torch.save(agent.qnetwork_local.state_dict(), \"{}/checkpoint.pth\".format(logger.log_file_path, date=datetime.datetime.now()))\n",
    "            return \n",
    "\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epi: 100\t Frame: 91883\tAverage Score: -20.3900\tMean: -21.0000\tDuration: 8.66\t#t_s: 820.0\n",
      "Epi: 200\t Frame: 186548\tAverage Score: -20.0700\tMean: -20.0000\tDuration: 11.03\t#t_s: 1077.0\n",
      "Epi: 300\t Frame: 317441\tAverage Score: -17.6300\tMean: -15.0000\tDuration: 16.11\t#t_s: 1558.0\n",
      "Epi: 400\t Frame: 471867\tAverage Score: -14.8700\tMean: -11.0000\tDuration: 20.16\t#t_s: 1957.0\n",
      "Epi: 500\t Frame: 635742\tAverage Score: -13.0300\tMean: -13.0000\tDuration: 17.94\t#t_s: 1735.0\n",
      "Epi: 600\t Frame: 854036\tAverage Score: -2.2200\tMean: 8.0000\tDuration: 24.72\t#t_s: 2394.0\n",
      "Epi: 700\t Frame: 1082486\tAverage Score: 9.3500\tMean: 16.0000\tDuration: 20.04\t#t_s: 1936.0\n",
      "Epi: 800\t Frame: 1290177\tAverage Score: 13.5900\tMean: 17.0000\tDuration: 19.76\t#t_s: 1907.0\n",
      "Epi: 900\t Frame: 1488982\tAverage Score: 14.9700\tMean: 18.0000\tDuration: 18.77\t#t_s: 1831.0\n",
      "Epi: 1000\t Frame: 1684547\tAverage Score: 16.0800\tMean: 21.0000\tDuration: 16.96\t#t_s: 1648.0\n",
      "Epi: 1100\t Frame: 1873532\tAverage Score: 17.3400\tMean: 17.0000\tDuration: 19.79\t#t_s: 1936.0\n",
      "Epi: 1200\t Frame: 2060107\tAverage Score: 18.0700\tMean: 20.0000\tDuration: 17.45\t#t_s: 1694.0\n",
      "Epi: 1300\t Frame: 2248499\tAverage Score: 17.7500\tMean: 20.0000\tDuration: 17.65\t#t_s: 1703.0\n"
     ]
    }
   ],
   "source": [
    "train(config, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(logger.score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "watch(config, logger.log_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
